from pyspark.ml.feature import PCA, PCAModel
from utils.spark_util import col_validation, one_hot_encoding, feature_cols_labelling, \
    one_hot_encoding_p, get_model_info, pca_to_df


def predict(sc, spark, dataframe, para_list, output):
    """Apply a previously trained PCA model to *dataframe*.

    Parameters
    ----------
    sc : SparkContext passed through to the helper utilities.
    spark : SparkSession used when converting results back to a DataFrame.
    dataframe : input Spark DataFrame containing the feature columns.
    para_list : dict with at least "feature_col", "model_id", "source", "k".
    output : dict with a "message" list; encoding warnings are appended to it.

    Returns
    -------
    Spark DataFrame with the PCA-transformed features, with the raw
    categorical columns dropped.
    """
    # --------------------prepare feature data------------------
    num_cols, cat_cols = col_validation(dataframe, para_list["feature_col"])

    model_save_path, feature_list_train = get_model_info(sc, para_list["model_id"])

    # Rebuild the per-feature category lists from the training feature names
    # (formatted "<feature>_OneHot/<category>") so one-hot encoding at
    # prediction time reproduces the exact training-time columns.
    # Fixed: key on the parsed feature name itself instead of cat_cols[i]
    # with a manually managed index -- the old code's membership check and
    # store used different keys, which broke (up to IndexError) whenever
    # cat_cols ordering did not match the training-column order.
    # maxsplit=1 guards against category values that contain "_OneHot/".
    # NOTE(review): assumes training feature names match the cat_cols names
    # returned by col_validation -- they should, both come from feature_col.
    cat_col_dict = {}
    for col in feature_list_train:
        if "_OneHot/" in col:
            feature, category = col.split("_OneHot/", 1)
            cat_col_dict.setdefault(feature, []).append(category)

    if len(cat_cols) != 0:
        dataframe, message, encoded_cat_cols = one_hot_encoding_p(
            sc, dataframe, para_list["source"], cat_cols, cat_col_dict)
        if len(message) != 0:
            output["message"].append(message)
    else:
        encoded_cat_cols = []

    feature_list = num_cols + encoded_cat_cols

    df = feature_cols_labelling(dataframe, feature_list)

    # --------------------load the model-----------------

    model = PCAModel.load(model_save_path)

    # ---------------------make predictions-----------------
    df_pred = pca_to_df(model.transform(df), para_list["k"], spark, feature_list)

    # Only encoded / PCA columns are returned; drop the raw categoricals.
    for col in cat_cols:
        df_pred = df_pred.drop(col)

    return df_pred


def train(sc, spark, dataframe, para_list, record, output):
    """Fit a PCA model on *dataframe* and record its metrics.

    Parameters
    ----------
    sc : SparkContext passed through to the helper utilities.
    spark : SparkSession used when converting results back to a DataFrame.
    dataframe : input Spark DataFrame containing the feature columns.
    para_list : dict with at least "feature_col", "source", "k".
    record : dict; "other_info" (feature list, explained-variance chart,
        principal-component matrix) is written into it.
    output : dict with a "message" list; encoding warnings are appended to it.

    Returns
    -------
    The fitted pyspark.ml.feature.PCAModel.
    """
    print("START TRAINING PCA")
    # --------------------prepare data------------------
    num_cols, cat_cols = col_validation(dataframe, para_list["feature_col"])

    if len(cat_cols) != 0:
        dataframe, message, encoded_cat_cols = one_hot_encoding(sc, dataframe, para_list["source"], cat_cols)
        if len(message) != 0:
            output["message"].append(message)
    else:
        encoded_cat_cols = []

    feature_list = num_cols + encoded_cat_cols
    df = feature_cols_labelling(dataframe, feature_list)

    # --------------------model fit----------------------
    print("MODEL FITTING")
    k = para_list["k"]
    pca = PCA(k=k, inputCol="features", outputCol="pca_features")

    model = pca.fit(df)
    print("MODEL FITTED")
    # ---------------------get metrics---------------

    explained_variance = model.explainedVariance.toArray().tolist()
    pc_matrix = model.pc
    print("PC MATRIX")
    # ----------------------record info-------------------
    # df_pred is computed (materializing the transform) but not returned;
    # kept for parity with the original flow.
    predictions = model.transform(df)
    print("TRANSFORMED")
    df_pred = pca_to_df(predictions, k, spark, feature_list)
    for col in cat_cols:
        df_pred = df_pred.drop(col)

    # Chart rows: [component index (1-based), its explained-variance ratio,
    # cumulative explained variance up to that component].
    # "cumulative" replaces a local named `sum`, which shadowed the builtin.
    ev_chart = []
    cumulative = 0
    for idx in range(k):
        cumulative = cumulative + explained_variance[idx]
        ev_chart.append([idx + 1, explained_variance[idx], cumulative])

    record["other_info"] = {
        "feature_list": feature_list,
        "feature_col": para_list["feature_col"],
        "ev_chart": ev_chart,
        "pcMatrix": pc_matrix.toArray().tolist(),
    }
    print("RETURN")
    return model
